# Bidirectional Spike Distillation (BSD) - Supplementary Code

## ⚙️ Environment Setup

### Prerequisites

Before running the code, please set up the environment properly:

#### 1. Install SpikingJelly (IMPORTANT!)

**DO NOT use `pip install spikingjelly`!**

The PyPI version (0.0.0.0.14) is outdated and lacks critical features required by this code. Please clone the latest version from GitHub into the current directory:

```bash
# Navigate to this directory (Code Supplementary)
cd Code\ Supplementary/

# Clone the latest SpikingJelly from GitHub into the current directory
git clone https://github.com/fangwei123456/spikingjelly.git

# Your directory structure should look like:
# Code Supplementary/              ← Current directory
# ├── README.md                     ← This file
# ├── main_spiking_bsd_train_cnn.py
# ├── models/
# ├── utils/
# ├── dataset/
# ├── yaml/
# └── spikingjelly/                 ← Clone here (same level as main script)
#     ├── spikingjelly/
#     ├── setup.py
#     └── README.md
```

**Why is this necessary?**
- The code uses `sys.path.append(os.path.join(os.path.dirname(__file__), 'spikingjelly'))` to import the local SpikingJelly
- We require features from the latest development version
- The stable PyPI release (0.0.0.0.14) is significantly behind the GitHub repository
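
The path handling described in the first bullet can be sketched as follows. This is an illustration only: the helper name `add_local_spikingjelly` is hypothetical, and the sketch uses the current working directory in place of the script's own directory.

```python
import os
import sys


def add_local_spikingjelly(script_dir):
    """Append <script_dir>/spikingjelly to sys.path (hypothetical helper)."""
    local_clone = os.path.join(script_dir, "spikingjelly")
    if local_clone not in sys.path:
        # append() places the clone after every entry already on the search path.
        sys.path.append(local_clone)
    return local_clone


# Inside a script, script_dir would typically be os.path.dirname(os.path.abspath(__file__)).
clone_path = add_local_spikingjelly(os.getcwd())
print(clone_path in sys.path)
```

Because `append` adds the clone at the end of the search path, a `spikingjelly` package already installed in `site-packages` would be found first, which is one more reason to avoid the PyPI install.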

#### 2. Install Other Dependencies

```bash
pip install torch torchvision pyyaml tqdm numpy
```

#### 3. Verify Installation

```bash
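# Test that the local SpikingJelly clone is importable (run from Code Supplementary/).
# The clone is added to sys.path first, mirroring what the training script does.
python -c "import sys; sys.path.append('spikingjelly'); from spikingjelly.activation_based import functional; print('✅ SpikingJelly loaded successfully!')"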
```

---

## 🚀 Quick Start

### Running the Training Script

To train the BSD model on CIFAR-10 with the default configuration:

```bash
python main_spiking_bsd_train_cnn.py
```

## 🔑 Core Mechanism: Local Learning with Gradient Detachment

### Overview

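BSD passes `detach_grad=True` to enforce **locality** in learning. Each layer detaches its input from the computation graph, so no gradient propagates from a layer to the layers before it. Each layer is then trained independently, using only the local signal obtained by comparing its forward features with the corresponding features of the backward network.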
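
The effect of detaching a layer's input can be seen in a minimal PyTorch sketch. It uses two plain `nn.Linear` layers rather than the spiking layers of this repository, and the names `layer1` and `layer2` are illustrative only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

layer1 = nn.Linear(4, 3)
layer2 = nn.Linear(3, 2)
x = torch.randn(5, 4)

h = layer1(x)
# Detaching h cuts the computation graph between the two layers.
out = layer2(h.detach())
out.sum().backward()

# layer2 received a gradient; layer1 did not, because the graph was cut.
print("layer1 grad:", layer1.weight.grad)
print("layer2 grad is set:", layer2.weight.grad is not None)
```

Without the `.detach()` call, the same `backward()` would also populate `layer1.weight.grad`, which is the global gradient flow that BSD avoids.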

### Implementation Details

#### 1. Training Forward Pass (Lines 1296-1316 in `main_spiking_bsd_train_cnn.py`)

```python
# ====================================================================
# FORWARD PASS with LOCAL LEARNING (BSD Algorithm)
# ====================================================================
# CRITICAL: detach_grad=True enforces LOCALITY in learning
# - Stops gradient flow between layers during forward propagation
# - Each layer learns independently using only local signals
# - This is a core principle of Bidirectional Spike Distillation (BSD)
# - Enables biologically plausible learning without global backprop
# ====================================================================
if batch_idx == 0:
    if args.use_prelif_for_loss:
        activations = model(data_fw.to(args.device),
                          detach_grad=True,              # ← LOCAL LEARNING
                          return_spikes_for_stats=False,
                          use_prelif_for_loss=True)
    else:
        activations = model(data_fw.to(args.device),
                          detach_grad=True,              # ← LOCAL LEARNING
                          return_spikes_for_stats=False,
                          use_prelif_for_loss=False)
```

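#### 2. Training Backward Pass (Lines 1319-1339 in `main_spiking_bsd_train_cnn.py`)

Here "backward pass" means a pass through the backward network via `model.reverse`, not backpropagation of gradients.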

```python
# ====================================================================
# BACKWARD PASS with LOCAL LEARNING (BSD Algorithm)
# ====================================================================
# CRITICAL: detach_grad=True enforces LOCALITY in learning
# - Stops gradient flow between layers during backward propagation
# - Backward network learns independently from forward network
# - Creates symmetric bidirectional learning without weight transport
# - This is essential for biologically plausible credit assignment
# ====================================================================
if batch_idx == 0:
    if args.use_prelif_for_loss:
        signals = model.reverse(data_bw.to(args.device),
                              detach_grad=True,          # ← LOCAL LEARNING
                              return_spikes_for_stats=False,
                              use_prelif_for_loss=True)
    else:
        signals = model.reverse(data_bw.to(args.device),
                              detach_grad=True,          # ← LOCAL LEARNING
                              return_spikes_for_stats=False,
                              use_prelif_for_loss=False)
```
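
In both excerpts the two branches differ only in the value passed for `use_prelif_for_loss`. Assuming that flag is a plain boolean, the calls could be written once by passing it through, as in the sketch below. `StubModel` and `run_passes` are hypothetical names, the stub only echoes its keyword arguments, and the `batch_idx` check and `.to(args.device)` transfer are left out for brevity.

```python
from types import SimpleNamespace


class StubModel:
    """Hypothetical stand-in for the BSD model; it only echoes its keyword arguments."""

    def __call__(self, data, **kwargs):
        return ("forward", kwargs)

    def reverse(self, data, **kwargs):
        return ("reverse", kwargs)


def run_passes(model, data_fw, data_bw, args):
    # Keyword arguments shared by the forward and backward network passes.
    common = dict(
        detach_grad=True,                # local learning
        return_spikes_for_stats=False,
        use_prelif_for_loss=bool(args.use_prelif_for_loss),
    )
    activations = model(data_fw, **common)
    signals = model.reverse(data_bw, **common)
    return activations, signals


args = SimpleNamespace(use_prelif_for_loss=True)
activations, signals = run_passes(StubModel(), data_fw=[0.0], data_bw=[1.0], args=args)
print(activations[1]["use_prelif_for_loss"], signals[1]["detach_grad"])
```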

#### 3. Layer-Level Implementation (Lines 142-212 in `models/spiking_cnn_layers.py`)

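Each spiking layer implements gradient detachment at the layer boundary. The excerpt below shows only the detachment and the main path; the `act`, `return_readout`, and `return_prelif` options are accepted, but their handling is not shown: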

```python
def forward(self, x, detach_grad=False, act=True, return_readout=False, return_prelif=False):
    """
    Forward pass with optional gradient detachment

    Args:
        detach_grad: Whether to detach gradients for local learning (default: False)
                    When True, gradients are blocked from backpropagating to previous layers,
                    ensuring locality in the learning algorithm (key for BSD/local learning)
    """
    # IMPORTANT: Gradient detachment for local learning
    # When detach_grad=True, we stop gradient flow from this layer to previous layers.
    # This enforces locality: each layer updates independently based only on local signals.
    # This is a critical component of the BSD (Bidirectional Spike Distillation) algorithm.
    if detach_grad:
        x = x.detach()    # ← BLOCKS gradient flow to previous layers

    # Forward pass following original structure
    if self.fw_bn == 0:
        x = self.forward_layer(x)
    elif self.fw_bn == 1:
        x = self.forward_layer(self.forward_bn(x))
    elif self.fw_bn == 2:
        x = self.forward_bn(self.forward_layer(x))

    # Store pre-LIF value for potential return
    pre_lif = x.clone() if return_prelif else None

    # Apply LIF activation - this gives spikes
    spike_features = self.forward_lif(x)

    # Return spike features (no gradient flow to previous layers when detach_grad=True)
    return spike_features
```

### Why This Matters

**Traditional Backpropagation:**
```
Layer 1 ←─ gradient ← Layer 2 ←─ gradient ← Layer 3 ←─ gradient ← Loss
   ↑                      ↑                      ↑
   └──────────────────────┴──────────────────────┘
        Global gradient flow (non-local)
```

**BSD with Gradient Detachment:**
```
Layer 1 ←─ local signal (forward features ↔ backward features)
   ⊗ (detached)
Layer 2 ←─ local signal (forward features ↔ backward features)
   ⊗ (detached)
Layer 3 ←─ local signal (forward features ↔ backward features)

Each layer learns independently using only LOCAL signals!
```
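
The independence shown in the diagram can be checked numerically. The sketch below is a toy under loud assumptions: it uses non-spiking `nn.Linear` layers, random tensors stand in for the backward network's features, and a mean-squared error stands in for the actual local loss. It verifies that the first layer's gradient is the same whether or not the second layer's loss is included.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

layers = nn.ModuleList([nn.Linear(8, 6), nn.Linear(6, 4)])
x = torch.randn(16, 8)
# Random tensors standing in for the backward network's features at each layer.
targets = [torch.randn(16, 6), torch.randn(16, 4)]


def local_losses(inputs):
    """Return one loss per layer; each layer sees a detached input."""
    losses = []
    h = inputs
    for layer, target in zip(layers, targets):
        h = torch.tanh(layer(h.detach()))  # no gradient reaches earlier layers
        losses.append(nn.functional.mse_loss(h, target))
    return losses


# Gradient of layer 0 from its own local loss only.
layers.zero_grad()
local_losses(x)[0].backward()
grad_own = layers[0].weight.grad.clone()

# Gradient of layer 0 when all local losses are summed.
layers.zero_grad()
sum(local_losses(x)).backward()
grad_all = layers[0].weight.grad.clone()

print("layer 0 unaffected by later losses:", torch.allclose(grad_own, grad_all))
```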

## 🎯 Loss Function: ReCo (Relaxed Contrastive) Loss

### Configuration

BSD uses the ReCo (Relaxed Contrastive) loss to align the features of the forward and backward networks. It is enabled and weighted in the YAML configuration:

```yaml
# ReCo (Relaxed Contrastive) Loss Configuration
three_term_use_relaxed_contrastive: true  # Enable ReCo loss (instead of standard InfoNCE)
three_term_reco_lambda: 0.6               # λ parameter for ReCo loss (balances positive/negative terms)
```
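
As a rough illustration of how these two keys might be read, the sketch below parses the fragment with PyYAML's `yaml.safe_load`. The inline string and the fallback defaults (`False` and `0.0`) are assumptions made for the example, not values taken from the training script.

```python
import yaml

# The same two keys as in the configuration fragment above.
config_text = """
three_term_use_relaxed_contrastive: true
three_term_reco_lambda: 0.6
"""

config = yaml.safe_load(config_text)

# Fallback defaults here are hypothetical; the real script may handle missing keys differently.
use_reco = bool(config.get("three_term_use_relaxed_contrastive", False))
reco_lambda = float(config.get("three_term_reco_lambda", 0.0))

print(use_reco, reco_lambda)
```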